{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7cc70088-51e3-4989-b9d8-b938d01bb152",
   "metadata": {},
   "source": [
    "# Structured Generation\n",
    "\n",
    "Using LLMs, we often need to use regex to parse the output of the LLM, like this one:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4efc233-6d47-4ea4-bd37-f2fcf2a53277",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "answer = \"\"\"The first 10 digits of pi (π) are as follows:\n",
    "\n",
    "3.1415926535\n",
    "\"\"\"\n",
    "\n",
    "regex = r\"([0-9]+)?\\.[0-9]+\"\n",
    "print(re.search(regex, answer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a85e612-e870-4c6a-ae17-5c4644f36425",
   "metadata": {},
   "source": [
    "In this practical, we will directly couple the regex with the LLM so that the output has exactly the right format!\n",
    "\n",
    "We follow the ideas of [Efficient Guided Generation for Large Language Models](https://arxiv.org/abs/2307.09702) by Brandon T. Willard, Rémi Louf.\n",
    "\n",
    "Regular expressions are a language that allows us to express patterns in text, and libraries can use a regex processor to find (sub)strings that match a given pattern. And if we’re able to match patterns, we’re also able to generate text that matches this pattern.\n",
    "\n",
    "The naive regex-guided generation process that loops over the entire vocabulary to find matches at each step:\n",
    "1. Start the generation with an empty prefix (the empty string).\n",
    "2. Concatenate every token of the vocabulary to the prefix. This gives you the set of all possible completions.\n",
    "3. For every possible completion use regex’s partial match feature to determine whether it can lead to a completion that matches the regex pattern. If that’s not the case, mask the token by setting its logit value to −∞.\n",
    "4. Use the masked logits to sample a new token.\n",
    "5. Concatenate the new token to the prefix and go back to (2).\n",
    "\n",
    "To generate a string that matches the regex `[0-9]+\\.[0-9]` and with a vocabulary `[\"a\", \".\", \".2\", \"1\"]`, starting with the empty string `““`, the following diagram illustrates the process we just described:\n",
    "![](https://cdn.prod.website-files.com/665725b00d910f65bec567fc/668c29d45780ee71a367c839_naive.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58ada73b-ca93-4383-aca6-fd1f9240e9b3",
   "metadata": {},
   "source": [
    "Another useful feature of regex processors is the ability to perform partial matches. A partial match is a string that matches up until its last character, thus one for which we could find completions that would have matched the pattern. This does not come with Python’s standard `re` library, but the [regex](https://github.com/mrabarnett/mrab-regex) library provides the functionality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c129a15e-2a6e-4dd0-92cf-16f4d9048f85",
   "metadata": {},
   "outputs": [],
   "source": [
    "import regex as re\n",
    "\n",
    "regex = r\"([0-9]+)?\\.[0-9]+\"\n",
    "print(re.fullmatch(regex, '.2', partial=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64e51fa8-81c2-4e65-933a-bcfcaf5c6bc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(re.fullmatch(regex, '1.2', partial=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a41990eb-4625-4840-91fc-cf34096006a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(re.fullmatch(regex, '1.2a', partial=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f386a49-e925-4fa2-94c9-1f9ca62bda17",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import regex as re\n",
    "import numpy as np\n",
    "\n",
    "from scipy.special import softmax\n",
    "\n",
    "np.random.seed(12349)\n",
    "\n",
    "logits = np.array([1., 1., 1., 1.])  # Random model with equal probabilities\n",
    "vocabulary = [\"a\", \".\", \".2\", \"1\"]\n",
    "\n",
    "regex = r\"([0-9]+)?\\.[0-9]+\"\n",
    "\n",
    "completion = \"\"\n",
    "for _ in range(7):\n",
    "\n",
    "    # Build the logit mask\n",
    "    #\n",
    "    # your code here\n",
    "    #\n",
    "    masked_logits = logits + mask\n",
    "\n",
    "    # Sample the next token\n",
    "    probs = softmax(masked_logits)\n",
    "    next_token_id = np.random.choice(len(vocabulary), p=probs)\n",
    "\n",
    "    completion += vocabulary[next_token_id]\n",
    "\n",
    "print(completion)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07e95486-edbe-4df1-ae51-89709b149fd9",
   "metadata": {},
   "source": [
    "This will work for any regex pattern supported by the regex library, and does not depend on the tokenization shenanigans that LLM researchers make the rest of us suffer through. Isn’t this lovely?\n",
    "\n",
    "Bad news: this algorithm will explode in your face. The typical vocabulary $𝑉$ of a large language model contains roughly 50,000 (fifty thousand) tokens, which means that for each token you want to generate you will need to perform 50,000 partial matches. In a language like Python, the time spent performing partial matches will easily dominate the time it takes to generate the next token. While we’ve just solved the problem in theory, this solution is unusable in practice.\n",
    "\n",
    "## Deterministic Finite Automaton (DFA) a kind of Finite-State Machine (FSM)\n",
    "\n",
    "1. Start in state 0 with the full string;\n",
    "2. Pop the first character of the string. If it matches any transition rule, it moves to the corresponding state. Otherwise we terminate and reject the string;\n",
    "3. Iterate until the string is either rejected or you reach one of the DFA’s final (also called accept) states.\n",
    "\n",
    "What does the DFA look like for a given regular expression? Let’s consider the regular expression than previously, `([0-9])?+\\.[0-9]+`, and use the [interegular](https://github.com/MegaIng/interegular) Python library to translate it to its equivalent DFA representation:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84b7e6d9-d678-4d0b-9eae-f7833733a08c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import interegular\n",
    "\n",
    "regex = r\"([0-9]+)?\\.[0-9]+\"\n",
    "fsm = interegular.parse_pattern(regex).to_fsm()\n",
    "\n",
    "print(fsm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13f5bef3-7538-4768-966d-e343da2e274c",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(fsm.alphabet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50f20be3-0b5c-4555-9f1b-6028e5bb7359",
   "metadata": {},
   "outputs": [],
   "source": [
    "for (start, transitions) in fsm.map.items():\n",
    "    print(start, transitions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70635a5c-b249-40e3-b386-28592197615f",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(fsm.initial)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a737eea-c576-4031-94aa-38f183e8616d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(fsm.finals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1caba2f2-a3ed-486a-a71f-8c3168f6e0da",
   "metadata": {},
   "outputs": [],
   "source": [
    "char_2_str = [' [0-9]', 'any', ' . ']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48fb3290-bd00-4a5f-bfca-6c2f59f201c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import graphviz\n",
    "from IPython.display import display\n",
    "\n",
    "# Define the regex pattern\n",
    "regex = r\"([0-9]+)?\\.[0-9]+\"\n",
    "\n",
    "# Convert regex to a finite state machine (FSM)\n",
    "fsm = interegular.parse_pattern(regex).to_fsm()\n",
    "\n",
    "# Generate Graphviz DOT format representation\n",
    "dot = graphviz.Digraph(format=\"png\")\n",
    "\n",
    "# Add states to the graph\n",
    "for state in fsm.states:\n",
    "    shape = \"doublecircle\" if state in fsm.finals else \"circle\"\n",
    "    dot.node(str(state), shape=shape)\n",
    "\n",
    "# Add transitions to the graph\n",
    "for (start, transitions) in fsm.map.items():\n",
    "    for char, end in transitions.items():\n",
    "        dot.edge(str(start), str(end), label=char_2_str[char] if char is not None else \"ε\")\n",
    "\n",
    "display(dot)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4351b6b-c71c-4ab7-b446-e94a59c3d539",
   "metadata": {},
   "outputs": [],
   "source": [
    "fsm.map"
   ]
  },
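  {
   "cell_type": "markdown",
   "id": "sketch-dfa-walk-md",
   "metadata": {},
   "source": [
    "The three DFA steps listed earlier can be sketched on a hand-written transition table. This is only an illustration under stated assumptions: the names `TRANSITIONS`, `INITIAL`, `FINALS`, `symbol_class` and `accepts` are hypothetical, and the state numbers are chosen by hand, so they need not match the ones `interegular` assigns. It checks whole strings only; matching a token from an arbitrary state is left to the exercise that follows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-dfa-walk-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand-written DFA for ([0-9]+)?\\.[0-9]+ (illustrative state numbering).\n",
    "TRANSITIONS = {\n",
    "    0: {\"digit\": 1, \"dot\": 2},  # start: nothing read yet\n",
    "    1: {\"digit\": 1, \"dot\": 2},  # read one or more leading digits\n",
    "    2: {\"digit\": 3},            # read the dot, a digit is now required\n",
    "    3: {\"digit\": 3},            # read at least one digit after the dot\n",
    "}\n",
    "INITIAL = 0\n",
    "FINALS = {3}\n",
    "\n",
    "\n",
    "def symbol_class(char):\n",
    "    \"\"\"Map a character to the symbol class used in TRANSITIONS.\"\"\"\n",
    "    if \"0\" <= char <= \"9\":\n",
    "        return \"digit\"\n",
    "    if char == \".\":\n",
    "        return \"dot\"\n",
    "    return \"other\"\n",
    "\n",
    "\n",
    "def accepts(string):\n",
    "    \"\"\"Return True if the DFA ends in a final state after reading `string`.\"\"\"\n",
    "    state = INITIAL\n",
    "    for char in string:\n",
    "        state = TRANSITIONS[state].get(symbol_class(char))\n",
    "        if state is None:  # no transition rule applies: reject\n",
    "            return False\n",
    "    return state in FINALS\n",
    "\n",
    "\n",
    "for candidate in [\"3.14\", \".2\", \"1.\", \"1.2a\"]:\n",
    "    print(repr(candidate), accepts(candidate))"
   ]
  },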
  {
   "cell_type": "markdown",
   "id": "3e1c828d-3dd9-4375-9eb4-3364b00bc613",
   "metadata": {},
   "source": [
    "First we need to figure out whether tokens in the vocabulary correspond to a valid path between states of the DFA. This is what the following function should do. If a token corresponds to a valid path, like “.2” we return the visited states. If it doesn’t, like ” The” we return None. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4fb53ac-1f82-4144-b9fa-01630e1b05f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def partial_match(state, token):\n",
    "    \"\"\"Partially match the token to the DFA starting from `state`.\n",
    "\n",
    "    We iterate over the token's symbols, and at each step transition to the \n",
    "    next state if we find a valid transition. \n",
    "    If there is a stage without a valid transision, we return None, otherwise\n",
    "    we return a tuple that contains the sequence of traversed states.\n",
    "\n",
    "    \"\"\"\n",
    "    \n",
    "    traversed_states = (state,)\n",
    "\n",
    "    # Iterate over the token's symbols, trying at each step to transition\n",
    "    # to a new DFA state.\n",
    "    #\n",
    "    # your code here\n",
    "    #\n",
    "    \n",
    "    return traversed_states"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50fccc57-7790-4e58-83ee-58d228711334",
   "metadata": {},
   "outputs": [],
   "source": [
    "token = \".21\"\n",
    "print(partial_match(0, token))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89109d5f-478a-4163-b017-cfa2938a2541",
   "metadata": {},
   "outputs": [],
   "source": [
    "token = \".21a\"\n",
    "print(partial_match(0, token))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c500642f-a8b9-473f-b475-11eba46bb11e",
   "metadata": {},
   "source": [
    "To build a map from the DFA’s states to tokens that correspond to valid completions, we need to loop over the states of the DFA, and for each state loop through the vocabulary to check whether tokens correspond to a valid path starting from this state:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e2de703-3602-4ea2-8f18-37755c61e1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "vocabulary = [\"a\", \".\", \".2\", \"1\"]\n",
    "\n",
    "# Map from the DFA states to the tokens that correspond to a valid transition\n",
    "# from this state.\n",
    "states_to_vocab = defaultdict(set)\n",
    "#\n",
    "# your code here\n",
    "#"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b1372f2-65eb-45a5-a316-e00e0d351caf",
   "metadata": {},
   "outputs": [],
   "source": [
    "states_to_vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15ca3bca-6622-430d-88b2-4cf88e1ef738",
   "metadata": {},
   "source": [
    "The generation can now proceed as follows:\n",
    "\n",
    "1. Start in state 0. Look for the tokens that lead to valid completions starting from 0 with `states_to_vocab[0]`.\n",
    "2. Mask the logits returned by the LLM so only these tokens can be sampled;\n",
    "3. Sample a new token using the logits. Look at the path that corresponds to `(state,token)`, the last state of the path corresponds to the next state of the DFA, `new_state`;\n",
    "4. Look for the tokens that lead to valid completions starting from `new_state` with `states_to_vocab[new_state]`.\n",
    "5. Go to (2) until the FSM is in one of its terminal states and this terminal state has no transition."
   ]
  },
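  {
   "cell_type": "markdown",
   "id": "sketch-masked-probs-md",
   "metadata": {},
   "source": [
    "Step 2 can be sketched in isolation. The snippet below is a minimal illustration, not the solution of the exercise: `masked_probabilities` is a hypothetical helper, the set of allowed tokens is hand-picked (in the practical it would come from `states_to_vocab[state]`), and the softmax is written in pure Python instead of using `scipy.special.softmax`. It assumes that at least one token is allowed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "sketch-masked-probs-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def masked_probabilities(vocabulary, logits, allowed):\n",
    "    \"\"\"Softmax over `logits` after masking tokens that are not in `allowed`.\"\"\"\n",
    "    # Forbidden tokens get a logit of -inf, so their probability is exactly 0.\n",
    "    masked = [\n",
    "        logit if token in allowed else -math.inf\n",
    "        for token, logit in zip(vocabulary, logits)\n",
    "    ]\n",
    "    highest = max(masked)  # subtract the maximum for numerical stability\n",
    "    exps = [math.exp(value - highest) for value in masked]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "# Hand-picked example: every token except \"a\" is allowed.\n",
    "demo_probs = masked_probabilities(\n",
    "    [\"a\", \".\", \".2\", \"1\"], [1.0, 1.0, 1.0, 1.0], {\".\", \".2\", \"1\"}\n",
    ")\n",
    "print(demo_probs)"
   ]
  },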
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3614e2ee-d54a-497a-bd5d-ca9fc5211472",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(12349) # you should get the same result as before\n",
    "\n",
    "logits = np.array([1., 1., 1., 1.])  # same as before\n",
    "\n",
    "regex = r\"([0-9]+)?\\.[0-9]+\"\n",
    "\n",
    "completion = \"\"\n",
    "state = fsm.initial\n",
    "for _ in range(7):\n",
    "\n",
    "    # Build the logit mask\n",
    "    #\n",
    "    # your code here\n",
    "    #\n",
    "\n",
    "print(completion)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e67bd22-8aea-46fa-925d-11dac7ef7edf",
   "metadata": {},
   "source": [
    "To go further: [Coalescence: making LLM inference 5x faster](https://blog.dottxt.co/coalescence.html)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dldiy",
   "language": "python",
   "name": "dldiy"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
